import tensorflow as tf

# ******* equal + argmax + cast + reduce_sum ***********
# Demo: compute classification accuracy for a batch of random class scores
# against random integer labels.
sample_num = 5
num_classes = 10  # shared by the score width and the label range below

# Random per-class scores, one row per sample.
scores = tf.random.normal([sample_num, num_classes], stddev=1, mean=0)
# Random labels; maxval is exclusive, so labels index a valid score column.
true_label = tf.random.uniform(
    [sample_num], minval=0, maxval=num_classes, dtype=tf.int64
)

# Predicted class = index of the highest score in each row. tf.argmax returns
# int64 by default, which matches true_label's dtype as tf.equal requires.
pred = tf.argmax(scores, axis=1)
print("pred:", pred.numpy())
print("true_label:", true_label.numpy())

# Element-wise match -> bool tensor; cast to int so the hits can be summed.
equal_a = tf.equal(pred, true_label)
equal_b = tf.cast(equal_a, dtype=tf.int32)
correct = tf.reduce_sum(equal_b)

# `/` on an integer tensor is true division and yields a floating-point scalar
# tensor; convert with .numpy() so the number is printed, not the tensor repr.
accuracy = correct / sample_num
print("accuracy:", accuracy.numpy())

